Data Exploration, Training Models, and XGBoost SHAP Values

Description

On this page, we present the plots and graphs used for data exploration, for training the machine learning models, and for exploring the interpretability of the XGBoost model via SHAP values.

In [1]:
# Imports 
import numpy as np
import pandas as pd 
df = pd.read_csv("author_sentiment.csv")
df.head(10)
Out[1]:
TITLE TARGET_ENTITY DOCUMENT TRUE_SENTIMENT text_cleaned text_processed num_uppercase num_first_pronoun num_second_pronoun num_third_pronoun ... white win woman work world would write year york young
0 German bank LBBW wins EU bailout approval Landesbank Baden Wuertemberg Germany's Landesbank Baden Wuertemberg won EU ... Negative Germany's Landesbank Baden Wuertemberg won EU ... Germany/NNP 's/POS Landesbank/NNP Baden/NNP Wu... 2.0 0.0 0.0 8.0 ... 0.0 0.123543 0.000000 0.000000 0.000000 0.232421 0.000000 0.143785 0.0 0.0
1 8th LD Writethru: 9th passenger released from ... Rolando Mendoza The Philippine National Police (PNP) identifie... Neutral The Philippine National Police (PNP) identifie... the/DT Philippine/NNP National/NNP Police/NNP ... 5.0 0.0 0.0 5.0 ... 0.0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.0 0.0
2 Commission: Bar Liberian president from office Charles Taylor Sirleaf 70 acknowledged before the commissio... Negative Sirleaf 70 acknowledged before the commission ... sirleaf/NN 70/CD acknowledge/VBD before/IN the... 0.0 1.0 0.0 16.0 ... 0.0 0.000000 0.000000 0.092898 0.000000 0.148777 0.000000 0.000000 0.0 0.0
3 AP Exclusive: Network flaw causes scary Web error Sawyers Sawyer logged off and asked her sister Mari ... Neutral Sawyer logged off and asked her sister Mari 31... Sawyer/NNP log/VBD off/RP and/CC ask/VBD her/P... 0.0 0.0 0.0 15.0 ... 0.0 0.000000 0.000000 0.000000 0.000000 0.000000 0.179045 0.000000 0.0 0.0
4 Holyfield ' s wife says boxer hit her several ... Candi Holyfield Candi Holyfield said in the protective order t... Neutral Candi Holyfield said in the protective order t... Candi/NNP Holyfield/NNP say/VBD in/IN the/DT p... 0.0 5.0 1.0 17.0 ... 0.0 0.000000 0.000000 0.000000 0.000000 0.113763 0.000000 0.105568 0.0 0.0
5 Hillary Clinton : Misogyny is ` endemic ' . Hillary Clinton -LRB- CNN -RRB- Hillary Clinton slammed what s... Neutral -LRB- CNN -RRB- Hillary Clinton slammed what s... -LRB-/NNP CNN/NNP -RRB-/NNP Hillary/NNP Clinto... 4.0 0.0 0.0 5.0 ... 0.0 0.000000 0.250242 0.000000 0.000000 0.000000 0.000000 0.000000 0.0 0.0
6 Trouser-wearing women fined $200 in Sudan Lubna Hussein Lubna Hussein was among 13 women arrested July... Neutral Lubna Hussein was among 13 women arrested July... Lubna/NNP Hussein/NNP be/VBD among/IN 13/CD wo... 1.0 0.0 0.0 14.0 ... 0.0 0.000000 0.344940 0.000000 0.224427 0.139241 0.000000 0.000000 0.0 0.0
7 Hillary Clinton Compares Donald Trump With Har... Hillary Clinton "A lot of people thought I was probably exagge... Neutral "A lot of people thought I was probably exagge... "/`` A/DT lot/NN of/IN people/NNS think/VBD I/... 0.0 2.0 0.0 3.0 ... 0.0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.0 0.0
8 Feature: "Chinese is an important part of my l... Maria Rukodelnikova Rukodelnikova is fond of a lot things from Chi... Positive Rukodelnikova is fond of a lot things from Chi... Rukodelnikova/NNP be/VBZ fond/JJ of/IN a/DT lo... 0.0 0.0 0.0 6.0 ... 0.0 0.000000 0.000000 0.183372 0.000000 0.000000 0.000000 0.136258 0.0 0.0
9 Former Australian Opposition leader attacks ne... Tony Abbott Former Australian Opposition leader Malcolm Tu... Neutral Former Australian Opposition leader Malcolm Tu... former/JJ Australian/NNP Opposition/NNP leader... 2.0 2.0 0.0 8.0 ... 0.0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.0 0.0

10 rows × 223 columns

In [2]:
# Configuration
from traitlets.config import Config
import nbformat as nbf
from nbconvert.exporters import HTMLExporter
from nbconvert.preprocessors import TagRemovePreprocessor

# Setup config
c = Config()

# Configure tag removal - tag the cells you want stripped with the word
# "remove_cell" (or substitute your own tag words below)
c.TagRemovePreprocessor.remove_cell_tags = ("remove_cell",)
c.TagRemovePreprocessor.remove_all_outputs_tags = ("remove_output",)
c.TagRemovePreprocessor.remove_input_tags = ("remove_input",)
c.TagRemovePreprocessor.enabled = True

# Configure our exporter
c.HTMLExporter.preprocessors = ["nbconvert.preprocessors.TagRemovePreprocessor"]

exporter = HTMLExporter(config=c)
exporter.register_preprocessor(TagRemovePreprocessor(config=c), True)

# Run the exporter - returns a tuple: the HTML first, notebook metadata second
output = exporter.from_filename("ml_models.ipynb")

# Write to output html file
with open("ml_models.html", "w") as f:
    f.write(output[0])
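The tag names configured above live in each cell's metadata, which the notebook UI normally hides. As a quick illustration (a hypothetical scratch notebook, not part of this analysis), this is how a cell acquires the `remove_cell` tag programmatically with `nbformat`:

```python
import nbformat as nbf

# Hypothetical scratch notebook: tag a cell so a TagRemovePreprocessor
# configured with remove_cell_tags=("remove_cell",) would strip it on export.
nb = nbf.v4.new_notebook()
cell = nbf.v4.new_code_cell("print('scratch work')")
cell.metadata["tags"] = ["remove_cell"]
nb.cells.append(cell)

print(nb.cells[0].metadata["tags"])
```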
In [3]:
from sklearn.preprocessing import LabelEncoder
label_encoder = LabelEncoder()
df['label'] = label_encoder.fit_transform(df['TRUE_SENTIMENT'].values)  # neg = 0, neu = 1, pos = 2
# print(df.head(5))

# Fix class imbalance. value_counts() sorts by frequency, so sort by class
# index before unpacking to keep the counts aligned with class labels 0, 1, 2.
class_count_0, class_count_1, class_count_2 = df['label'].value_counts().sort_index()

class_0 = df[df['label'] == 0] # neg
class_1 = df[df['label'] == 1] # neu
class_2 = df[df['label'] == 2] # pos
print('class 0:', class_0.shape)
print('class 1:', class_1.shape)
print('class 2:', class_2.shape)

print(class_count_0)
class_0_over = class_0.sample(class_count_1, replace=True, ignore_index=True)  # oversample negatives
class_2_under = class_2.sample(class_count_1, ignore_index=True)               # undersample positives

new_df = pd.concat([class_1, class_0_over, class_2_under], axis=0, ignore_index=True)

print("total class of 0, 1 and 2:", new_df['label'].value_counts()) # count of each class after resampling
new_df['label'].value_counts().plot(kind='bar', title='Number of each class')
class 0: (351, 224)
class 1: (1246, 224)
class 2: (1758, 224)
351
total class of 0, 1 and 2: label
1    1246
0    1246
2    1246
Name: count, dtype: int64
Out[3]:
<Axes: title={'center': 'Number of each class'}, xlabel='label'>
[Figure: bar chart of the number of samples in each class after resampling]
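The cell above oversamples the negative class with replacement and undersamples the positive class so all three classes match the neutral count. A minimal self-contained sketch of the same resampling idea on a toy frame (the toy labels are illustrative only):

```python
import pandas as pd

# Toy imbalanced labels: 2 negatives (0), 4 neutrals (1), 6 positives (2)
toy = pd.DataFrame({"label": [0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2]})
target = int((toy["label"] == 1).sum())  # balance everything to the neutral count

over = toy[toy["label"] == 0].sample(target, replace=True, random_state=0)    # oversample
under = toy[toy["label"] == 2].sample(target, replace=False, random_state=0)  # undersample
balanced = pd.concat([toy[toy["label"] == 1], over, under], ignore_index=True)

print(balanced["label"].value_counts().to_dict())
```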

Correlation Heat Map for Features 1-17 and Target Label

In [4]:
# Correlation matrix
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

# Calculate the correlation matrix
correlation_matrix = new_df.iloc[:, list(range(6, 23)) + [-1]].corr()

# Plot the heatmap
plt.figure(figsize=(15, 15))
sns.heatmap(correlation_matrix, annot=False, cmap='coolwarm', fmt=".2f", annot_kws={"size": 10})
plt.title('Correlation Heatmap')
plt.show()
[Figure: correlation heatmap of features 1-17 and the target label]
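A heatmap this size is easier to act on alongside a ranked list of the features most correlated with the label. A sketch of that ranking on synthetic stand-in data (the column names below are placeholders for the real features):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
demo = pd.DataFrame({
    "num_past_verb": rng.normal(size=100),  # constructed to correlate with the label
    "avg_len_sen": rng.normal(size=100),    # independent noise
})
demo["label"] = (demo["num_past_verb"] > 0).astype(int)

# Rank features by absolute correlation with the label column
corr_with_label = demo.corr()["label"].drop("label").abs().sort_values(ascending=False)
print(corr_with_label.index[0])
```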

Scatter Plot of Average Length of Token by Each Sentiment Class

In [5]:
# Visualization I
# !pip install plotly
import plotly.express as px 
 
# plotting the bubble chart
fig = px.scatter(new_df, x="TRUE_SENTIMENT", y="avg_len_token",
                 size="avg_len_sen", color="TRUE_SENTIMENT")
fig.update_layout(title='Scatter plot of average length of token by each sentiment class',
                  xaxis_title='Sentiment Class',
                  yaxis_title='Average Length Token')
 
# showing the plot
fig.show()

Scatter Plot of Number of Proper Nouns by Each Sentiment Class

In [6]:
# Visualization II
!pip install plotly
import plotly.express as px 
 
# plotting the bubble chart
fig = px.scatter(new_df, x="TRUE_SENTIMENT", y="num_proper_noun",
                 size="avg_len_sen", color="TRUE_SENTIMENT")
fig.update_layout(title='Scatter plot of number of proper nouns by each sentiment class',
                  xaxis_title='Sentiment Class',
                  yaxis_title='Number of Proper Nouns')
 
 
# showing the plot
fig.show()
Requirement already satisfied: plotly in /opt/conda/lib/python3.11/site-packages (5.20.0)
Requirement already satisfied: tenacity>=6.2.0 in /opt/conda/lib/python3.11/site-packages (from plotly) (8.2.3)
Requirement already satisfied: packaging in /opt/conda/lib/python3.11/site-packages (from plotly) (23.2)

Scatter Plot of Number of Past-Tense Verbs by Each Sentiment Class

In [7]:
# Visualization III
!pip install plotly
import plotly.express as px 
 
# plotting the bubble chart
fig = px.scatter(new_df, x="TRUE_SENTIMENT", y="num_past_verb",
                 size="avg_len_sen", color="TRUE_SENTIMENT")
fig.update_layout(title='Scatter plot of number of past-tense verbs by each sentiment class',
                  xaxis_title='Sentiment Class',
                  yaxis_title='Number of Past-Tense Verbs')
 
 
# showing the plot
fig.show()
Requirement already satisfied: plotly in /opt/conda/lib/python3.11/site-packages (5.20.0)
Requirement already satisfied: tenacity>=6.2.0 in /opt/conda/lib/python3.11/site-packages (from plotly) (8.2.3)
Requirement already satisfied: packaging in /opt/conda/lib/python3.11/site-packages (from plotly) (23.2)
In [8]:
# Separate feat and label - data process
feat = new_df.iloc[:, 6:-2]  # note: -2 excludes the last two columns, 'young' and 'label'
feat_names = list(feat.columns)
label = new_df.loc[:, ['label']]

print(new_df)
print(feat.shape)
print(label.shape)
print(feat_names)
                                                  TITLE      TARGET_ENTITY  \
0     8th LD Writethru: 9th passenger released from ...    Rolando Mendoza   
1     AP Exclusive: Network flaw causes scary Web error            Sawyers   
2     Holyfield ' s wife says boxer hit her several ...    Candi Holyfield   
3          Hillary Clinton : Misogyny is ` endemic ' .     Hillary Clinton   
4             Trouser-wearing women fined $200 in Sudan      Lubna Hussein   
...                                                 ...                ...   
3733  Thousands travel to Yangon for Pope's diplomat...       Pope Francis   
3734  Er the Oscars' In Memoriam section showed a wo...        Jan Chapman   
3735            Big cat: Jack needs new home and a diet          Jack Jack   
3736  Woman accuses George H.W. Bush of groping her ...         H. W. Bush   
3737  Interview: Language cornerstone of Russia-Chin...  Elizabeth Pavlova   

                                               DOCUMENT TRUE_SENTIMENT  \
0     The Philippine National Police (PNP) identifie...        Neutral   
1     Sawyer logged off and asked her sister  Mari  ...        Neutral   
2     Candi Holyfield said in the protective order t...        Neutral   
3     -LRB- CNN -RRB- Hillary Clinton slammed what s...        Neutral   
4     Lubna Hussein was among 13 women arrested July...        Neutral   
...                                                 ...            ...   
3733  YANGON (Reuters) - Thousands of Catholics gath...       Positive   
3734  Advertisement - Continue Reading Below\nAustra...       Positive   
3735  Skip in Skip x Embed x Share CLOSE  Jack  the ...       Positive   
3736  Another woman has come forward to accuse  form...       Positive   
3737  "It's important for Russia and China to cultiv...       Positive   

                                           text_cleaned  \
0     The Philippine National Police (PNP) identifie...   
1     Sawyer logged off and asked her sister Mari 31...   
2     Candi Holyfield said in the protective order t...   
3     -LRB- CNN -RRB- Hillary Clinton slammed what s...   
4     Lubna Hussein was among 13 women arrested July...   
...                                                 ...   
3733  YANGON (Reuters) - Thousands of Catholics gath...   
3734  Advertisement - Continue Reading Below Austral...   
3735  Skip in Skip x Embed x Share CLOSE Jack the 30...   
3736  Another woman has come forward to accuse forme...   
3737  "It's important for Russia and China to cultiv...   

                                         text_processed  num_uppercase  \
0     the/DT Philippine/NNP National/NNP Police/NNP ...            5.0   
1     Sawyer/NNP log/VBD off/RP and/CC ask/VBD her/P...            0.0   
2     Candi/NNP Holyfield/NNP say/VBD in/IN the/DT p...            0.0   
3     -LRB-/NNP CNN/NNP -RRB-/NNP Hillary/NNP Clinto...            4.0   
4     Lubna/NNP Hussein/NNP be/VBD among/IN 13/CD wo...            1.0   
...                                                 ...            ...   
3733  YANGON/NNP (/-LRB- Reuters/NNP )/-RRB- -/: tho...            8.0   
3734  advertisement/NN -/: continue/VB read/VBG belo...            0.0   
3735  skip/VB in/IN Skip/NNP x/SYM Embed/NNP x/NN \n...            2.0   
3736  another/DT woman/NN have/VBZ come/VBN forward/...            3.0   
3737  "/`` it/PRP be/VBZ important/JJ for/IN Russia/...            0.0   

      num_first_pronoun  num_second_pronoun  num_third_pronoun  ...  win  \
0                   0.0                 0.0                5.0  ...  0.0   
1                   0.0                 0.0               15.0  ...  0.0   
2                   5.0                 1.0               17.0  ...  0.0   
3                   0.0                 0.0                5.0  ...  0.0   
4                   0.0                 0.0               14.0  ...  0.0   
...                 ...                 ...                ...  ...  ...   
3733                0.0                 0.0               18.0  ...  0.0   
3734                5.0                 0.0                8.0  ...  0.0   
3735                0.0                 0.0                6.0  ...  0.0   
3736                2.0                 0.0               27.0  ...  0.0   
3737                0.0                 0.0                8.0  ...  0.0   

         woman      work     world     would     write      year  york  \
0     0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   0.0   
1     0.000000  0.000000  0.000000  0.000000  0.179045  0.000000   0.0   
2     0.000000  0.000000  0.000000  0.113763  0.000000  0.105568   0.0   
3     0.250242  0.000000  0.000000  0.000000  0.000000  0.000000   0.0   
4     0.344940  0.000000  0.224427  0.139241  0.000000  0.000000   0.0   
...        ...       ...       ...       ...       ...       ...   ...   
3733  0.133844  0.000000  0.000000  0.081042  0.000000  0.000000   0.0   
3734  0.000000  0.312165  0.000000  0.000000  0.000000  0.000000   0.0   
3735  0.000000  0.000000  0.000000  0.107136  0.000000  0.198836   0.0   
3736  0.388910  0.000000  0.000000  0.000000  0.000000  0.000000   0.0   
3737  0.000000  0.000000  0.000000  0.166829  0.087332  0.361226   0.0   

         young  label  
0     0.000000      1  
1     0.000000      1  
2     0.000000      1  
3     0.000000      1  
4     0.000000      1  
...        ...    ...  
3733  0.000000      2  
3734  0.000000      2  
3735  0.000000      2  
3736  0.000000      2  
3737  0.101125      2  

[3738 rows x 224 columns]
(3738, 216)
(3738, 1)
['num_uppercase', 'num_first_pronoun', 'num_second_pronoun', 'num_third_pronoun', 'num_coord_conj', 'num_past_verb', 'num_future_verb', 'num_comma', 'num_multi_punc', 'num_common_noun', 'num_proper_noun', 'num_adverb', 'num_wh', 'num_slang', 'avg_len_sen', 'avg_len_token', 'num_sen', '000', '10', '2016', '2017', 'accord', 'add', 'allegation', 'also', 'american', 'another', 'around', 'ask', 'back', 'become', 'begin', 'believe', 'big', 'bill', 'bush', 'business', 'call', 'campaign', 'case', 'change', 'charge', 'child', 'city', 'claim', 'clinton', 'close', 'come', 'company', 'continue', 'could', 'country', 'court', 'cruz', 'day', 'deal', 'donald', 'early', 'election', 'end', 'even', 'face', 'family', 'far', 'feel', 'find', 'fire', 'first', 'follow', 'force', 'former', 'four', 'game', 'get', 'give', 'go', 'good', 'government', 'great', 'group', 'head', 'help', 'high', 'hold', 'home', 'house', 'include', 'interview', 'issue', 'itâ', 'james', 'job', 'keep', 'know', 'last', 'late', 'later', 'law', 'lead', 'leader', 'leave', 'life', 'like', 'live', 'long', 'look', 'lose', 'lot', 'lsb', 'make', 'man', 'many', 'may', 'medium', 'meet', 'meeting', 'member', 'million', 'monday', 'month', 'move', 'much', 'name', 'national', 'need', 'never', 'new', 'news', 'next', 'night', 'north', 'office', 'official', 'old', 'one', 'open', 'part', 'party', 'pay', 'people', 'percent', 'photo', 'place', 'plan', 'play', 'point', 'police', 'political', 'post', 'president', 'public', 'put', 'question', 'really', 'release', 'report', 'republican', 'return', 'reuters', 'right', 'romney', 'rsb', 'run', 'sanders', 'say', 'school', 'season', 'second', 'see', 'senate', 'set', 'sexual', 'share', 'show', 'since', 'speak', 'start', 'state', 'statement', 'still', 'story', 'support', 'take', 'talk', 'team', 'tell', 'thing', 'think', 'though', 'three', 'thursday', 'time', 'top', 'trump', 'try', 'tuesday', 'turn', 'two', 'united', 'use', 'vote', 'want', 'washington', 'way', 'wednesday', 'week', 'well', 'white', 
'win', 'woman', 'work', 'world', 'would', 'write', 'year', 'york']
In [9]:
# Train-Test split 
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    feat, label, test_size=0.2, random_state=1100)
print(X_train)
print(X_test)
print(y_train)
print(y_test)
      num_uppercase  num_first_pronoun  num_second_pronoun  num_third_pronoun  \
3212            1.0                5.0                 8.0               14.0   
325             0.0                0.0                 0.0                7.0   
1255            0.0                0.0                 0.0                1.0   
454             0.0                0.0                 0.0                7.0   
645             0.0                2.0                 2.0               10.0   
...             ...                ...                 ...                ...   
968             2.0                0.0                 0.0                9.0   
2288            0.0                5.0                 2.0               27.0   
2991            4.0                4.0                 0.0               14.0   
2093            2.0                0.0                 1.0                5.0   
3462            1.0                1.0                 0.0                2.0   

      num_coord_conj  num_past_verb  num_future_verb  num_comma  \
3212            30.0           14.0              2.0        0.0   
325              3.0            6.0              0.0        0.0   
1255             3.0            1.0              0.0        3.0   
454              3.0            6.0              0.0        0.0   
645              4.0            6.0              0.0        0.0   
...              ...            ...              ...        ...   
968              4.0           10.0              0.0        0.0   
2288            14.0           16.0              3.0        0.0   
2991            14.0           24.0              2.0        0.0   
2093             3.0           16.0              0.0        0.0   
3462             4.0            2.0              0.0        0.0   

      num_multi_punc  num_common_noun  ...      well  white       win  \
3212            61.0            183.0  ...  0.071516    0.0  0.000000   
325             11.0             38.0  ...  0.000000    0.0  0.000000   
1255             6.0              8.0  ...  0.000000    0.0  0.271660   
454             11.0             38.0  ...  0.000000    0.0  0.000000   
645             12.0             24.0  ...  0.142114    0.0  0.000000   
...              ...              ...  ...       ...    ...       ...   
968             15.0             43.0  ...  0.000000    0.0  0.000000   
2288            52.0             85.0  ...  0.086881    0.0  0.106739   
2991            49.0             91.0  ...  0.067443    0.0  0.082859   
2093            21.0             57.0  ...  0.000000    0.0  0.000000   
3462             8.0             21.0  ...  0.107220    0.0  0.000000   

         woman     work     world     would     write      year      york  
3212  0.000000  0.00000  0.000000  0.055099  0.000000  0.000000  0.093510  
325   0.000000  0.00000  0.000000  0.000000  0.000000  0.315297  0.000000  
1255  0.000000  0.00000  0.000000  0.000000  0.000000  0.000000  0.000000  
454   0.000000  0.00000  0.000000  0.000000  0.000000  0.315297  0.000000  
645   0.000000  0.00000  0.000000  0.000000  0.000000  0.000000  0.000000  
...        ...      ...       ...       ...       ...       ...       ...  
968   0.000000  0.00000  0.000000  0.415975  0.000000  0.000000  0.000000  
2288  0.110547  0.00000  0.000000  0.066936  0.000000  0.000000  0.000000  
2991  0.000000  0.06489  0.000000  0.000000  0.081601  0.000000  0.088184  
2093  0.000000  0.00000  0.126974  0.000000  0.000000  0.000000  0.000000  
3462  0.000000  0.00000  0.000000  0.165213  0.000000  0.000000  0.000000  

[2990 rows x 216 columns]
      num_uppercase  num_first_pronoun  num_second_pronoun  num_third_pronoun  \
3655            5.0                6.0                 1.0               32.0   
743             1.0                0.0                 0.0               14.0   
445             0.0                5.0                 0.0               10.0   
558             5.0                4.0                 0.0               13.0   
1930            6.0                6.0                 3.0               41.0   
...             ...                ...                 ...                ...   
1263            0.0                1.0                 1.0               30.0   
1749            2.0                0.0                 0.0                6.0   
2754            9.0                4.0                 2.0               17.0   
2080            1.0                0.0                 0.0                1.0   
3255            0.0                6.0                 1.0               11.0   

      num_coord_conj  num_past_verb  num_future_verb  num_comma  \
3655            19.0           13.0              3.0        0.0   
743              5.0           16.0              0.0        0.0   
445             10.0           14.0              1.0        0.0   
558              8.0           25.0              0.0        0.0   
1930            16.0           40.0              2.0        0.0   
...              ...            ...              ...        ...   
1263             7.0           35.0              0.0        0.0   
1749             7.0           17.0              0.0        0.0   
2754             8.0           17.0              1.0        0.0   
2080             1.0            7.0              0.0        6.0   
3255            10.0           20.0              1.0        0.0   

      num_multi_punc  num_common_noun  ...      well    white       win  \
3655            52.0             84.0  ...  0.000000  0.06664  0.000000   
743             17.0             53.0  ...  0.000000  0.00000  0.000000   
445             31.0             71.0  ...  0.000000  0.00000  0.000000   
558             38.0            100.0  ...  0.000000  0.00000  0.000000   
1930           107.0            135.0  ...  0.056953  0.00000  0.000000   
...              ...              ...  ...       ...      ...       ...   
1263            57.0             73.0  ...  0.000000  0.00000  0.000000   
1749            15.0             46.0  ...  0.000000  0.00000  0.000000   
2754            34.0             48.0  ...  0.000000  0.00000  0.000000   
2080            19.0             13.0  ...  0.000000  0.00000  0.000000   
3255            60.0             40.0  ...  0.173591  0.00000  0.319905   

         woman      work  world     would     write      year      york  
3655  0.000000  0.000000    0.0  0.040776  0.000000  0.000000  0.000000  
743   0.000000  0.000000    0.0  0.000000  0.000000  0.000000  0.000000  
445   0.000000  0.000000    0.0  0.130476  0.000000  0.000000  0.000000  
558   0.000000  0.097771    0.0  0.156582  0.000000  0.000000  0.000000  
1930  0.289869  0.054797    0.0  0.043879  0.000000  0.162872  0.000000  
...        ...       ...    ...       ...       ...       ...       ...  
1263  0.108881  0.000000    0.0  0.000000  0.000000  0.244712  0.223773  
1749  0.506746  0.000000    0.0  0.076709  0.000000  0.071183  0.000000  
2754  0.000000  0.083497    0.0  0.000000  0.000000  0.000000  0.000000  
2080  0.000000  0.000000    0.0  0.000000  0.110757  0.000000  0.000000  
3255  0.000000  0.083510    0.0  0.133742  0.000000  0.000000  0.000000  

[748 rows x 216 columns]
      label
3212      2
325       1
1255      0
454       1
645       1
...     ...
968       1
2288      0
2991      2
2093      0
3462      2

[2990 rows x 1 columns]
      label
3655      2
743       1
445       1
558       1
1930      0
...     ...
1263      0
1749      0
2754      2
2080      0
3255      2

[748 rows x 1 columns]
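The split above draws the test fold at random. The resampled data is balanced overall, but a random split can still drift slightly; passing `stratify` preserves class proportions exactly in both folds. A sketch with synthetic stand-ins for `feat` and `label`:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in labels with a 100/200/300 class split
X = np.arange(600).reshape(-1, 1)
y = np.repeat([0, 1, 2], [100, 200, 300])

X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.2, random_state=1100, stratify=y)

print(np.bincount(y_te))  # class proportions preserved in the 20% test fold
```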
In [10]:
# Train Xgboost - default
!pip install xgboost
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
import xgboost as xgb

# Init classifier
xgb_cl = xgb.XGBClassifier()

# Fit
xgb_cl.fit(X_train, y_train)

# Predict
preds = xgb_cl.predict(X_test)
accuracy = accuracy_score(y_test, preds)
f1 = f1_score(y_test, preds, average='macro')

# Score
print("F1 Score:", f1)
print("Accuracy:", accuracy)
Requirement already satisfied: xgboost in /opt/conda/lib/python3.11/site-packages (2.0.3)
Requirement already satisfied: numpy in /opt/conda/lib/python3.11/site-packages (from xgboost) (1.24.4)
Requirement already satisfied: scipy in /opt/conda/lib/python3.11/site-packages (from xgboost) (1.11.4)
F1 Score: 0.6538556985903352
Accuracy: 0.6564171122994652
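Macro-F1 and accuracy summarize all three classes at once; a confusion matrix shows where the errors land per class, which is worth checking next to scores like these. A self-contained sketch on toy labels (0 = negative, 1 = neutral, 2 = positive):

```python
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score

# Toy ground truth and predictions for the three sentiment classes
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])

cm = confusion_matrix(y_true, y_pred)  # rows = true class, columns = predicted class
macro = f1_score(y_true, y_pred, average='macro')

print(cm)
print(macro)
```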
In [11]:
# Grid search for Xgboost
from sklearn.model_selection import GridSearchCV

# Define the hyperparameter grid
param_grid = {
    'max_depth': [17, 20, 22], # 6, 7, 8, 10, 12, 15, 17, 20, 22
    'learning_rate': [0.02, 0.03], # 0.008, 0.009, 0.01, 0.02,0.03
    'subsample': [0.5] # 0.5, 0.6 
}

# Create the XGBoost model object
xgb_model = xgb.XGBClassifier()

# Create the GridSearchCV object
grid_search = GridSearchCV(xgb_model, param_grid, cv=5, scoring='accuracy')

# Fit the GridSearchCV object to the training data
grid_search.fit(X_train, y_train)

# Print the best set of hyperparameters and the corresponding score
print("Best set of hyperparameters: ", grid_search.best_params_)
print("Best score: ", grid_search.best_score_)
Best set of hyperparameters:  {'learning_rate': 0.03, 'max_depth': 20, 'subsample': 0.5}
Best score:  0.6678929765886288
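With the default `refit=True`, `GridSearchCV` retrains the best configuration on the full training set, so `best_estimator_` can be used for prediction directly. A sketch using a stand-in `DecisionTreeClassifier` on synthetic data so it runs without xgboost installed:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic three-class problem as a stand-in for the sentiment features
X, y = make_classification(n_samples=200, n_classes=3, n_informative=4,
                           random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)

search = GridSearchCV(DecisionTreeClassifier(random_state=0),
                      {"max_depth": [3, 5, 7]}, cv=5, scoring="accuracy")
search.fit(X_tr, y_tr)

best = search.best_estimator_  # already refit on all of X_tr
preds = best.predict(X_te)
print(search.best_params_)
```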
In [12]:
# Train logistic regression 
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score

# Initialize the logistic regression model
model = LogisticRegression()

# Train the model
model.fit(X_train, y_train.values.ravel())  # ravel() avoids the column-vector warning

# Predict on the testing set
y_pred = model.predict(X_test)

# Evaluate the model
accuracy = accuracy_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred, average='macro')

print("F1 Score:", f1)
print("Accuracy:", accuracy)
F1 Score: 0.431597278233198
Accuracy: 0.43716577540106955
/opt/conda/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning:

lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression

In [13]:
# Train Decision Tree
from sklearn.metrics import accuracy_score, f1_score
from sklearn.tree import DecisionTreeClassifier

# Initialize the decision tree classifier
dt_model = DecisionTreeClassifier()

# Train the model
dt_model.fit(X_train, y_train)

# Predict on the testing set
y_pred = dt_model.predict(X_test)

# Evaluate the model
accuracy = accuracy_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred, average='macro')

print("F1 Score:", f1)
print("Accuracy:", accuracy)
F1 Score: 0.6109020781073954
Accuracy: 0.6216577540106952
In [14]:
# Save best performing model
from sklearn.metrics import f1_score
model_file_path = 'xgboost_model.bin'
xgb_model = xgb.XGBClassifier(max_depth = 22, learning_rate = 0.02, subsample = 0.5)
xgb_model.fit(X_train, y_train)
preds = xgb_model.predict(X_test)
f1 = f1_score(y_test, preds, average='macro')

print("F1 Score:", f1)
print('Accuracy:', accuracy_score(y_test, preds))

# Save the trained model
xgb_model.save_model(model_file_path)
F1 Score: 0.6805725726632295
Accuracy: 0.6831550802139037
/opt/conda/lib/python3.11/site-packages/xgboost/core.py:160: UserWarning:

[20:43:26] WARNING: /workspace/src/c_api/c_api.cc:1240: Saving into deprecated binary model format, please consider using `json` or `ubj`. Model format will default to JSON in XGBoost 2.2 if not specified.

Sentiment Analysis of ChatGPT's Responses

In [15]:
# Predict on ChatGPT's responses

gpt3_df = pd.read_csv("response_gpt3.csv")
gpt3_df = gpt3_df.iloc[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 18], :-1]

gpt4_df = pd.read_csv("response_gpt4.csv")
gpt4_df = gpt4_df.iloc[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 18], :-1]

preds_gpt3 = xgb_model.predict(gpt3_df) # make prediction on GPT-3's responses
preds_gpt4 = xgb_model.predict(gpt4_df) # make prediction on GPT-4's responses

gpt3_df['predicted_sentiment'] = label_encoder.inverse_transform(preds_gpt3)
gpt4_df['predicted_sentiment'] = label_encoder.inverse_transform(preds_gpt4)

print(gpt3_df)
print(gpt4_df)
    num_uppercase  num_first_pronoun  num_second_pronoun  num_third_pronoun  \
1             0.0                2.0                 0.0                0.0   
2             1.0                0.0                 0.0               12.0   
3             1.0                0.0                 0.0               10.0   
4             0.0                0.0                 0.0                6.0   
5             0.0                1.0                 0.0               11.0   
6             0.0                0.0                 0.0               20.0   
7             2.0                0.0                 0.0               12.0   
8             0.0                0.0                 0.0                3.0   
9             4.0                0.0                 0.0                5.0   
10            0.0                4.0                 0.0                3.0   
17            2.0                0.0                 0.0                4.0   
18            0.0                0.0                 0.0                7.0   

    num_coord_conj  num_past_verb  num_future_verb  num_comma  num_multi_punc  \
1             10.0            2.0              0.0       18.0            37.0   
2              7.0            2.0              3.0       12.0            33.0   
3             13.0            0.0              1.0        9.0            24.0   
4             10.0            2.0              0.0        6.0            24.0   
5              8.0            6.0              0.0       14.0            25.0   
6             13.0            7.0              0.0       10.0            23.0   
7             12.0           10.0              1.0       16.0            29.0   
8              9.0            2.0              0.0        8.0            22.0   
9             12.0            5.0              0.0        9.0            22.0   
10            10.0            0.0              0.0       14.0            30.0   
17             6.0            0.0              0.0        9.0            29.0   
18            10.0            5.0              0.0        7.0            26.0   

    num_common_noun  ...  white  win  woman      work     world     would  \
1              86.0  ...    0.0  0.0    0.0  0.104915  0.000000  0.000000   
2              57.0  ...    0.0  0.0    0.0  0.000000  0.000000  0.000000   
3              71.0  ...    0.0  0.0    0.0  0.000000  0.000000  0.000000   
4              66.0  ...    0.0  0.0    0.0  0.191973  0.000000  0.000000   
5              56.0  ...    0.0  0.0    0.0  0.129537  0.000000  0.000000   
6              65.0  ...    0.0  0.0    0.0  0.000000  0.000000  0.000000   
7              62.0  ...    0.0  0.0    0.0  0.000000  0.165213  0.000000   
8              64.0  ...    0.0  0.0    0.0  0.000000  0.000000  0.099061   
9              88.0  ...    0.0  0.0    0.0  0.000000  0.000000  0.000000   
10             71.0  ...    0.0  0.0    0.0  0.000000  0.403363  0.750776   
17             68.0  ...    0.0  0.0    0.0  0.000000  0.000000  0.000000   
18             66.0  ...    0.0  0.0    0.0  0.000000  0.000000  0.000000   

    write      year  york  predicted_sentiment  
1     0.0  0.000000   0.0             Negative  
2     0.0  0.000000   0.0              Neutral  
3     0.0  0.000000   0.0             Negative  
4     0.0  0.000000   0.0             Negative  
5     0.0  0.000000   0.0             Negative  
6     0.0  0.000000   0.0             Negative  
7     0.0  0.000000   0.0              Neutral  
8     0.0  0.091925   0.0             Negative  
9     0.0  0.000000   0.0             Negative  
10    0.0  0.000000   0.0             Negative  
17    0.0  0.000000   0.0             Negative  
18    0.0  0.000000   0.0             Negative  

[12 rows x 217 columns]
    num_uppercase  num_first_pronoun  num_second_pronoun  num_third_pronoun  \
1             0.0                0.0                 1.0                4.0   
2             2.0                1.0                 0.0                9.0   
3             2.0                1.0                 0.0                5.0   
4             1.0                4.0                 0.0                9.0   
5             3.0                1.0                 0.0               15.0   
6             0.0                3.0                 0.0               14.0   
7             4.0                0.0                 0.0               17.0   
8             0.0                3.0                 0.0                0.0   
9             2.0                2.0                 0.0                5.0   
10            0.0                4.0                 0.0                8.0   
17            3.0                0.0                 0.0                4.0   
18            1.0                0.0                 0.0               10.0   

    num_coord_conj  num_past_verb  num_future_verb  num_comma  num_multi_punc  \
1              9.0            5.0              0.0       19.0            40.0   
2              2.0            3.0              2.0       12.0            28.0   
3              8.0            2.0              1.0       13.0            29.0   
4             11.0            6.0              0.0        9.0            27.0   
5              6.0           15.0              0.0       15.0            36.0   
6             10.0           18.0              0.0       20.0            48.0   
7             10.0           15.0              1.0       15.0            40.0   
8              5.0           10.0              0.0       15.0            32.0   
9             13.0           10.0              1.0       11.0            25.0   
10            12.0            4.0              0.0       21.0            43.0   
17             6.0            2.0              2.0       10.0            27.0   
18             9.0            6.0              1.0       15.0            36.0   

    num_common_noun  ...  white  win  woman      work     world     would  \
1              78.0  ...    0.0  0.0    0.0  0.105211  0.135790  0.000000   
2              54.0  ...    0.0  0.0    0.0  0.000000  0.000000  0.000000   
3              71.0  ...    0.0  0.0    0.0  0.000000  0.000000  0.000000   
4              62.0  ...    0.0  0.0    0.0  0.175081  0.000000  0.000000   
5              58.0  ...    0.0  0.0    0.0  0.000000  0.000000  0.000000   
6              69.0  ...    0.0  0.0    0.0  0.117525  0.000000  0.000000   
7              48.0  ...    0.0  0.0    0.0  0.000000  0.000000  0.000000   
8              60.0  ...    0.0  0.0    0.0  0.000000  0.000000  0.000000   
9              65.0  ...    0.0  0.0    0.0  0.000000  0.000000  0.000000   
10             73.0  ...    0.0  0.0    0.0  0.473087  0.203529  0.252551   
17             55.0  ...    0.0  0.0    0.0  0.000000  0.000000  0.000000   
18             56.0  ...    0.0  0.0    0.0  0.000000  0.000000  0.000000   

       write      year      york  predicted_sentiment  
1   0.000000  0.000000  0.142979             Positive  
2   0.000000  0.000000  0.000000             Negative  
3   0.000000  0.000000  0.000000              Neutral  
4   0.000000  0.000000  0.000000             Negative  
5   0.000000  0.000000  0.000000             Negative  
6   0.147791  0.174658  0.000000             Positive  
7   0.000000  0.000000  0.000000             Positive  
8   0.000000  0.071425  0.000000             Negative  
9   0.000000  0.000000  0.000000              Neutral  
10  0.000000  0.000000  0.000000             Positive  
17  0.000000  0.000000  0.000000             Negative  
18  0.000000  0.058539  0.000000              Neutral  

[12 rows x 217 columns]
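Printouts like the two above pair each document's `TRUE_SENTIMENT` annotation with a model's `predicted_sentiment`. A quick way to summarize that agreement is a confusion cross-tab; a minimal sketch on toy rows (the labels below are illustrative, not taken from the dataset):

```python
import pandas as pd

# Hypothetical miniature of the tables above: gold labels vs. one
# model's predicted_sentiment column.
df = pd.DataFrame({
    "TRUE_SENTIMENT":      ["Negative", "Neutral", "Negative", "Positive"],
    "predicted_sentiment": ["Negative", "Negative", "Neutral",  "Positive"],
})

# Cross-tabulate to see where predictions agree with the annotations:
# diagonal cells are correct, off-diagonal cells are confusions.
confusion = pd.crosstab(df["TRUE_SENTIMENT"], df["predicted_sentiment"])
print(confusion)
```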

SHAP Analysis of XGBoost¶

In [17]:
# Conduct SHAP analysis
# !pip install shap 
# !pip install xgboost
import shap
from sklearn.metrics import accuracy_score, f1_score
import xgboost as xgb
# print the JS visualization code to the notebook
shap.initjs()

# model_file_path = 'xgboost_model.bin'
# loaded_model = xgb.Booster()
# loaded_model.load_model(model_file_path)

# preds = loaded_model.predict(X_test)

# Compute SHAP values to assess feature importance
# (xgb_model, X_test and feat were fit/defined in the training cells above)
explainer = shap.TreeExplainer(xgb_model)
shap_values = explainer.shap_values(X_test, check_additivity=False)
print(shap_values.shape)
# print(explainer.expected_value)

shap.force_plot(explainer.expected_value[0], shap_values[:, :, 0]) # class: negative
(748, 216, 3)
Out[17]:
[Interactive SHAP force plot — renders only in a trusted notebook session with the shap JavaScript library loaded]
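The force plot is a visual decomposition of SHAP's additivity property: for each row and class, the class base value (`explainer.expected_value[c]`) plus the sum of that row's per-feature SHAP values reconstructs the model's margin output for that class. A minimal numpy sketch with illustrative shapes in place of the real arrays (which are 748 rows × 216 features × 3 classes here):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins for the notebook's arrays; the names and random values
# below are illustrative only, not taken from the trained model.
n_rows, n_feats, n_classes = 5, 4, 3
shap_vals = rng.normal(size=(n_rows, n_feats, n_classes))
expected_value = rng.normal(size=n_classes)  # per-class base values

# What the force plot draws: base value + sum of per-feature SHAP
# values = the model's margin output, per row and per class.
margin = expected_value + shap_vals.sum(axis=1)
print(margin.shape)  # (5, 3)
```

(With `check_additivity=False`, as in the cell above, shap skips verifying this identity, but it is still what the plot depicts.)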
In [18]:
shap.summary_plot(shap_values[:, :, 0], features=feat, plot_type="bar") # class: negative
[Figure: SHAP summary bar plot for the Negative class]
In [19]:
shap.summary_plot(shap_values[:, :, 1], features=feat, plot_type="bar") # class: neutral
[Figure: SHAP summary bar plot for the Neutral class]
In [20]:
shap.summary_plot(shap_values[:, :, 2], features=feat, plot_type="bar") # class: positive
[Figure: SHAP summary bar plot for the Positive class]
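The bar-style summary plots above rank features by their mean absolute SHAP value over the test set. A small sketch of that ranking computation, using synthetic values and hypothetical feature names in place of one class slice such as `shap_values[:, :, 0]`:

```python
import numpy as np

rng = np.random.default_rng(1)

# Illustrative stand-in for one class slice of the SHAP array;
# these four feature names and scales are hypothetical.
feature_names = np.array(["num_comma", "num_past_verb", "work", "year"])
class_shap = rng.normal(scale=[2.0, 0.5, 1.0, 0.1], size=(100, 4))

# The bar summary plot sorts features by mean |SHAP| across rows.
importance = np.abs(class_shap).mean(axis=0)
ranking = feature_names[np.argsort(importance)[::-1]]
print(ranking)
```

Features with larger average absolute contributions (here, the column drawn with the largest scale) land at the top of the bar chart.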
In [21]:
shap.force_plot(explainer.expected_value[1], shap_values[0, :, 1], features=feat.iloc[0, :]) # class: neutral
Out[21]:
[Interactive SHAP force plot for a single instance — renders only in a trusted notebook session with the shap JavaScript library loaded]